# -*- coding: UTF-8 -*-  
import matplotlib.pyplot as plt
import numpy as np

import tensorflow as tf
# Switch TensorFlow to graph (non-eager) execution for the rest of the script.
tf.compat.v1.disable_eager_execution()
# Enable memory growth on every GPU that TensorFlow reports.
# NOTE(review): this presumably has to run before the GPUs are first used,
# which is why it sits ahead of the remaining imports -- confirm before moving.
for gpu in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)
from tensorflow.keras import layers, Model, optimizers

from agents.dqn import DQN
from agents.sfdqn import SFDQN
from agents.dqn_ft import DQN_FT
from agents.sfdqn_ft import SFDQN_FT
from agents.buffer import ReplayBuffer
from features.deep import DeepSF
from features.deep_ft import DeepSF_FT
from tasks.reacher import Reacher
from utils.config import parse_config_file

# Load the experiment configuration once; the functions below read these
# module-level values.
config_params = parse_config_file('reacher.cfg')

# one entry per configuration section
gen_params = config_params['GENERAL']
task_params = config_params['TASK']
agent_params = config_params['AGENT']
dqn_params = config_params['DQN']
sfdqn_params = config_params['SFDQN']

# sample budget handed to the agents' train() calls
n_samples = gen_params['n_samples']

# train targets come first in all_goals, test targets follow, so a test
# target sits at offset len(goals) + i
goals = task_params['train_targets']
test_goals = task_params['test_targets']
all_goals = goals + test_goals


# tasks
def generate_tasks(include_target):
    """Create the pre-training task list and the held-out test task list.

    Args:
        include_target: Forwarded unchanged as the third argument of every
            ``Reacher`` (presumably whether the target is part of the
            observation -- confirm in ``tasks.reacher``).

    Returns:
        ``(train_tasks, test_tasks)``: two lists of ``Reacher`` instances,
        all built over ``all_goals``. Train tasks use index 0 only; test
        tasks use the indices that follow the train targets.
    """
    # CHANGED by the original author: len(goals) --> 1, i.e. only the first
    # train target is used for training.
    n_train = 1
    train_tasks = [Reacher(all_goals, idx, include_target) for idx in range(n_train)]

    # test targets are stored after the train targets inside all_goals
    first_test = len(goals)
    test_tasks = [
        Reacher(all_goals, idx, include_target)
        for idx in range(first_test, first_test + len(test_goals))
    ]
    return train_tasks, test_tasks

# generate tasks for fine tuning, with test tasks matching the train tasks (tasks 5 through 12)
def generate_ft_tasks(include_target):
    """Create the fine-tuning tasks, with test tasks mirroring the train tasks.

    Both returned lists are built the same way (one ``Reacher`` per test
    target, indexed after the train targets in ``all_goals``), but they hold
    separate ``Reacher`` instances.

    Args:
        include_target: Forwarded unchanged as the third argument of every
            ``Reacher``.

    Returns:
        ``(train_tasks, test_tasks)``: two equally constructed lists of
        ``Reacher`` instances.
    """
    def build_test_target_tasks():
        # test targets start right after the train targets in all_goals
        shift = len(goals)
        return [
            Reacher(all_goals, shift + k, include_target)
            for k in range(len(test_goals))
        ]

    # two independent builds: training and evaluation must not share objects
    train_tasks = build_test_target_tasks()
    test_tasks = build_test_target_tasks()
    return train_tasks, test_tasks


# keras model
def dqn_model_lambda():
    """Build and compile the Keras network used as the DQN Q-function.

    The hidden layers come from ``dqn_params['keras_params']``: the entries
    of ``n_neurons`` and ``activations`` are paired up, one ``Dense`` layer
    per pair. The network takes a 6-dimensional input and ends in a linear
    ``Dense`` layer with 9 outputs.
    NOTE(review): 6 and 9 are presumably the observation size and the number
    of discrete actions of the Reacher task -- confirm in ``tasks.reacher``.

    Returns:
        A ``Model`` compiled with the Adam optimizer (learning rate taken
        from ``keras_params['learning_rate']``) and the ``'mse'`` loss.
    """
    keras_params = dqn_params['keras_params']

    # Pass the shape as an explicit tuple: the Keras API documents ``shape``
    # as a tuple, and a bare int is outside that contract and may be rejected
    # by newer Keras releases. (6,) yields the same (None, 6) input as before.
    x = y = layers.Input(shape=(6,))
    for n_neurons, activation in zip(keras_params['n_neurons'], keras_params['activations']):
        y = layers.Dense(n_neurons, activation=activation)(y)
    y = layers.Dense(9, activation='linear')(y)

    model = Model(inputs=x, outputs=y)
    # the optimizer is Adam (the previous local name `sgd` was misleading)
    optimizer = optimizers.Adam(learning_rate=keras_params['learning_rate'])
    model.compile(optimizer, 'mse')
    return model


# keras model for the SF
def sf_model_lambda(x):
    """Build and compile the Keras successor-feature network on top of ``x``.

    Hidden ``Dense`` layers are created from the paired ``n_neurons`` /
    ``activations`` entries of ``sfdqn_params['keras_params']``. The final
    linear layer has ``9 * len(all_goals)`` units and is reshaped to
    ``(9, len(all_goals))``.
    NOTE(review): 9 is presumably the number of discrete actions -- confirm.

    Args:
        x: Keras input tensor used as the model input.

    Returns:
        A ``Model`` from ``x`` to the reshaped output, compiled with Adam
        (learning rate from ``keras_params['learning_rate']``) and ``'mse'``.
    """
    n_features = len(all_goals)
    hyper = sfdqn_params['keras_params']

    hidden = x
    for width, act in zip(hyper['n_neurons'], hyper['activations']):
        hidden = layers.Dense(width, activation=act)(hidden)

    # one block of n_features outputs per action, then split into two axes
    flat_out = layers.Dense(9 * n_features, activation='linear')(hidden)
    shaped_out = layers.Reshape((9, n_features))(flat_out)

    model = Model(inputs=x, outputs=shaped_out)
    model.compile(optimizers.Adam(learning_rate=hyper['learning_rate']), 'mse')
    return model

def test_model_freeze():  # working
    """Manual check that freezing a layer works on the DQN network.

    Builds a fresh DQN model, prints its summary, marks the last layer as
    non-trainable and prints the summary again so the two can be compared.
    """
    net = dqn_model_lambda()
    net.summary()

    # freeze only the output layer
    output_layer = net.layers[-1]
    output_layer.trainable = False
    net.summary()

def train(retrain=True, pretrained_path="pretrained_sfdqn_800000", suff=None):
    """Fine tune (or retrain) a pretrained SFDQN on the fine-tuning tasks.

    Only the SFDQN fine-tuning stage is active. The other stages (SFDQN and
    DQN pre-training, DQN fine tuning) were already disabled in the script
    and are kept below as commented-out code; uncomment a stage to run it.

    Calling ``train()`` without arguments behaves exactly like the previous
    hard-coded version (retrain, ``pretrained_sfdqn_800000``, suffix
    ``rt_800000``).

    Args:
        retrain: Forwarded to ``DeepSF_FT``. Per the original inline note it
            selects between retraining and fine tuning the last layer.
        pretrained_path: Path of the saved Keras model passed to
            ``tf.keras.models.load_model``.
        suff: Suffix identifying the performance file, which is written as
            ``sfdqn_perf_{suff}.npy``. When omitted it is ``'rt_800000'`` if
            ``retrain`` is true and ``'ft_800000'`` otherwise -- the two file
            names ``plot_fig`` loads.
    """
    if suff is None:
        suff = 'rt_800000' if retrain else 'ft_800000'

    # ----------------------------------- SFDQN (disabled) -------------------------------

    # # build SFDQN
    # print('building SFDQN')
    # deep_sf = DeepSF(keras_model_handle=sf_model_lambda, **sfdqn_params)
    # sfdqn = SFDQN(deep_sf=deep_sf, buffer=ReplayBuffer(sfdqn_params['buffer_params']),
    #               **sfdqn_params, **agent_params)

    # # train SFDQN
    # print('training SFDQN')
    # train_tasks, test_tasks = generate_tasks(False)
    # train_samples = 8*n_samples
    # sfdqn_perf = sfdqn.train(train_tasks, train_samples, test_tasks=test_tasks, n_test_ev=agent_params['n_test_ev'])
    # # save pretrained sfdqn for later use
    # print('saving SFDQN ...')
    # sfdqn.sf.psi[-1][0].save(f'pretrained_sfdqn_{train_samples}')
    # print('Done')
    # # save the sfdqn performance
    # np.save(f'sfdqn_perf_{train_samples}', sfdqn_perf)

    # ----------------------------------- DQN (disabled) ---------------------------------
    # NOTE: this stage uses `train_samples`, which is defined in the SFDQN
    # stage above; define it here too if this stage is enabled on its own.

    # # build DQN
    # print('building DQN')
    # dqn = DQN(model_lambda=dqn_model_lambda, buffer=ReplayBuffer(dqn_params['buffer_params']),
    #           **dqn_params, **agent_params)

    # # training DQN
    # print('training DQN')
    # train_tasks, test_tasks = generate_tasks(True)
    # dqn_perf = dqn.train(train_tasks, train_samples, test_tasks=test_tasks, n_test_ev=agent_params['n_test_ev'])
    # # save pretrained dqn for later use
    # print('saving DQN ...')
    # dqn.Q.save(f'pretrained_dqn_{train_samples}')
    # print('Done')
    # # save the dqn performance
    # np.save(f'dqn_perf_{train_samples}', dqn_perf)

    # fine tuning to new tasks
    print('fine tuning to new tasks')

    # ------------------------------- DQN FT (disabled) ----------------------------------

    # print('fine tuning DQN')
    # # loading pretrained model
    # # from previous run
    # pretrained_model = tf.keras.models.load_model("pretrained_dqn_400000")
    # # # .. or from current run
    # # pretrained_model = dqn_model_lambda()
    # # pretrained_model.set_weights(dqn.Q.get_weights())

    # # create dqn_ft agent
    # dqn_ft = DQN_FT(model_lambda=dqn_model_lambda, buffer=ReplayBuffer(dqn_params['buffer_params']), pretrained_model=pretrained_model,
    #           **dqn_params, **agent_params)
    # train_tasks_ft, test_tasks_ft = generate_ft_tasks(True)
    # dqn_ft_perf = dqn_ft.train(train_tasks_ft, n_samples, test_tasks=test_tasks_ft, n_test_ev=agent_params['n_test_ev'])
    # # save the dqn_ft performance
    # # NOTE(review): this writes 'dqn_ft_perf_400000.npy' while plot_fig()
    # # loads 'dqn_perf_ft_400000.npy' -- confirm which name is intended.
    # np.save('dqn_ft_perf_400000', dqn_ft_perf)

    # -------------------------------- SFDQN FT (active) ---------------------------------

    print('fine tuning SFDQN')
    # load the pretrained model saved by a previous run
    pretrained_sf_model = tf.keras.models.load_model(pretrained_path)

    # `retrain` decides whether to retrain or to fine tune the last layer;
    # `suff` identifies the model / log file / performance file of this run
    deep_sf_ft = DeepSF_FT(keras_model_handle=sf_model_lambda, pretrained_model=pretrained_sf_model,
                           retrain=retrain, **sfdqn_params)
    sfdqn_ft = SFDQN_FT(deep_sf=deep_sf_ft, buffer=ReplayBuffer(sfdqn_params['buffer_params']),
                        **sfdqn_params, **agent_params)

    # fine tune SFDQN
    train_tasks_ft, test_tasks_ft = generate_ft_tasks(False)
    sfdqn_ft_perf = sfdqn_ft.train(train_tasks_ft, n_samples, test_tasks=test_tasks_ft,
                                   n_test_ev=agent_params['n_test_ev'])
    # save the fine tuned sfdqn performance
    np.save(f'sfdqn_perf_{suff}', sfdqn_ft_perf)

    # ------------------------------------------------------------------------------------


def plot_fig():
    """Plot the smoothed test performance of the three fine-tuning variants.

    Loads three performance arrays from ``.npy`` files in the current working
    directory (DQN fine tuning, SFDQN fine tuning, SFDQN retraining), smooths
    each with a 10-point moving average, drops the last 5 samples and draws
    the curves over an x axis running from 0 to 8. The figure is written to
    ``figures/dqn_400000_sfdqn_ft_comp_800000.png``; the ``figures``
    directory is created if it does not exist yet.

    Raises:
        FileNotFoundError: if one of the ``.npy`` inputs is missing.
    """
    import os  # local import: only needed to create the output directory

    window = 10  # moving-average width, in samples
    trim = 5     # samples dropped from the end of every smoothed curve
    out_path = 'figures/dqn_400000_sfdqn_ft_comp_800000.png'

    dqn_ft_perf = np.load('dqn_perf_ft_400000.npy')
    sfdqn_ft_perf = np.load('sfdqn_perf_ft_800000.npy')
    sfdqn_retrain_perf = np.load('sfdqn_perf_rt_800000.npy')

    # smooth data
    def smooth(y, box_pts):
        # box filter; mode='same' keeps the output as long as the input
        return np.convolve(y, np.ones(box_pts) / box_pts, mode='same')

    # NOTE(review): the trailing samples are presumably dropped because the
    # window runs past the end of the data there; the leading samples are
    # affected the same way but are kept, as in the original script.
    sfdqn_ft_perf = smooth(sfdqn_ft_perf, window)[:-trim]  # CHANGED
    dqn_ft_perf = smooth(dqn_ft_perf, window)[:-trim]
    sfdqn_retrain_perf = smooth(sfdqn_retrain_perf, window)[:-trim]
    # shared x axis; this requires all three curves to have the same length
    x = np.linspace(0, 8, dqn_ft_perf.size)

    # reporting progress
    ticksize = 14
    textsize = 18
    plt.rc('font', size=textsize)  # controls default text sizes
    plt.rc('axes', titlesize=textsize)  # fontsize of the axes title
    plt.rc('axes', labelsize=textsize)  # fontsize of the x and y labels
    plt.rc('xtick', labelsize=ticksize)  # fontsize of the tick labels
    plt.rc('ytick', labelsize=ticksize)  # fontsize of the tick labels
    plt.rc('legend', fontsize=ticksize)  # legend fontsize

    plt.figure(figsize=(8, 6))
    ax = plt.gca()
    ax.plot(x, sfdqn_ft_perf, label='SFDQN_FT')  # CHANGED
    ax.plot(x, sfdqn_retrain_perf, label='SFDQN_RT')  # CHANGED
    ax.plot(x, dqn_ft_perf, label='DQN_FT')
    plt.xlabel('testing task index')
    plt.ylabel('averaged test episode reward')
    plt.title('Testing Reward for each Test Tasks')
    plt.tight_layout()
    plt.legend(frameon=False)

    # savefig does not create missing directories, so make sure it exists
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path)



# Script entry points: uncomment the call(s) to run. They execute at module
# level, so they also run whenever this file is imported.

# DEBUG: print the DQN model summary before and after freezing its last layer
# test_model_freeze()

# PLOT FIGURES: load the saved performance arrays and write the comparison plot
plot_fig()

# RUN: fine tune the pretrained SFDQN and save its performance array
# train()
